# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os

import numpy as np
import paddle
from paddle.io import BatchSampler
from paddle.io import DataLoader
from tqdm import tqdm
from yacs.config import CfgNode

from paddlespeech.audio.metric import compute_eer
from paddlespeech.s2t.utils.log import Log
from paddlespeech.vector.io.batch import batch_feature_normalize
from paddlespeech.vector.io.dataset import CSVDataset
from paddlespeech.vector.io.embedding_norm import InputNormalization
from paddlespeech.vector.models.ecapa_tdnn import EcapaTdnn
from paddlespeech.vector.modules.sid_model import SpeakerIdetification
from paddlespeech.vector.training.seeding import seed_everything

logger = Log(__name__).getlog()


def compute_dataset_embedding(data_loader, model, mean_var_norm_emb, config,
                              id2embedding):
    """Extract an embedding for every utterance served by ``data_loader``.

    The embeddings are written into ``id2embedding`` in place; the function
    itself returns nothing.

    Args:
        data_loader (paddle.io.DataLoader): loader whose batches provide the
            'ids', 'feats' and 'lengths' entries
        model (paddle.nn.Layer): the speaker verification model; only its
            ``backbone`` is called here
        mean_var_norm_emb: callable used for the embedding mean and std norm;
            a falsy value (e.g. None) disables the normalization
        config (yacs.config.CfgNode): the yaml config
        id2embedding (dict): utterance id -> embedding mapping, updated in place
    """
    csv_path = data_loader.dataset.csv_path
    logger.info(f'Computing embeddings on {csv_path} dataset')

    with paddle.no_grad():
        for batch in tqdm(data_loader):
            # stage 8-1: extract the audio embedding
            utt_ids = batch['ids']
            feats = batch['feats']
            feat_lengths = batch['lengths']

            # (N, emb_size, 1) -> (N, emb_size)
            embeddings = model.backbone(feats, feat_lengths)
            embeddings = embeddings.squeeze(-1)

            # Global embedding normalization: according to the original
            # authors it reduces the EER by roughly 10% relative.
            if config.global_embedding_norm:
                if mean_var_norm_emb:
                    # every embedding is passed with a length entry of one
                    unit_lengths = paddle.ones([embeddings.shape[0]])
                    embeddings = mean_var_norm_emb(embeddings, unit_lengths)

            # Update embedding dict, one entry per utterance of the batch.
            for utt_id, embedding in zip(utt_ids, embeddings):
                id2embedding[utt_id] = embedding


def _cohort_score_stats(embedding, train_cohort, cos_sim_func,
                        cohort_size=None):
    """Return the (mean, std) of the cosine scores of one embedding vs. the cohort."""
    # Repeat the embedding once per cohort entry so that it can be compared
    # with the whole cohort in a single cosine-similarity call.
    tiled = paddle.tile(embedding, repeat_times=[train_cohort.shape[0], 1])
    cohort_scores = cos_sim_func(tiled, train_cohort)
    if cohort_size is not None:
        # keep only the highest-scoring cohort entries
        cohort_scores, _ = paddle.topk(cohort_scores, k=cohort_size, axis=0)
    return paddle.mean(cohort_scores, axis=0), paddle.std(cohort_scores, axis=0)


def compute_verification_scores(id2embedding, train_cohort, config):
    """Compute the verification trial scores

    Every non-empty line of ``config.verification_file`` is expected to hold
    three whitespace separated fields: ``label enroll_id test_id``. Both ids
    are mapped to embedding keys by dropping everything from the first '.'
    and replacing '/' with '-'.

    Args:
        id2embedding (dict): the utterance embedding, keyed by utterance id
        train_cohort (paddle.tensor): the cohort dataset embedding; only read
            when ``config.score_norm`` is "s-norm", "z-norm" or "t-norm"
        config (yacs.config.CfgNode): the yaml config; ``verification_file``
            is required, ``score_norm`` and ``cohort_size`` are optional

    Returns:
        the scores and the trial labels,
        1 refers the target and 0 refers the nontarget in labels.
        Without score normalization each score is the plain number returned by
        ``.item()``; with normalization it is the result of arithmetic with
        paddle tensors.

    Raises:
        KeyError: if a trial references an id that is missing in ``id2embedding``
        ValueError: if a non-empty line does not contain exactly three fields
    """
    logger.info(f"read the trial from {config.verification_file}")
    cos_sim_func = paddle.nn.CosineSimilarity(axis=-1)

    # The config does not change while scoring, so look these up only once.
    score_norm = config.score_norm if "score_norm" in config else None
    cohort_size = config.cohort_size if "cohort_size" in config else None

    scores = []
    labels = []
    with open(config.verification_file, 'r') as f:
        # Iterate lazily instead of loading the whole trial list in memory.
        for line in f:
            line = line.strip()
            if not line:
                # tolerate blank lines (e.g. a trailing newline at the end)
                continue
            label, enroll_id, test_id = line.split()
            enroll_id = enroll_id.split('.')[0].replace('/', '-')
            test_id = test_id.split('.')[0].replace('/', '-')
            labels.append(int(label))

            enroll_emb = id2embedding[enroll_id]
            test_emb = id2embedding[test_id]
            score = cos_sim_func(enroll_emb, test_emb).item()

            # Only compute the cohort statistics the selected norm needs.
            # An unrecognized score_norm value leaves the raw score untouched.
            if score_norm == "s-norm":
                mean_e_c, std_e_c = _cohort_score_stats(
                    enroll_emb, train_cohort, cos_sim_func, cohort_size)
                mean_t_c, std_t_c = _cohort_score_stats(
                    test_emb, train_cohort, cos_sim_func, cohort_size)
                score_e = (score - mean_e_c) / std_e_c
                score_t = (score - mean_t_c) / std_t_c
                score = 0.5 * (score_e + score_t)
            elif score_norm == "z-norm":
                # normalize with the enroll impostor statistics
                mean_e_c, std_e_c = _cohort_score_stats(
                    enroll_emb, train_cohort, cos_sim_func, cohort_size)
                score = (score - mean_e_c) / std_e_c
            elif score_norm == "t-norm":
                # normalize with the test impostor statistics
                mean_t_c, std_t_c = _cohort_score_stats(
                    test_emb, train_cohort, cos_sim_func, cohort_size)
                score = (score - mean_t_c) / std_t_c

            scores.append(score)

    return scores, labels


def _build_eval_loader(csv_path, config, **dataset_kwargs):
    """Build a non-shuffled melspectrogram DataLoader for the given csv file."""
    dataset = CSVDataset(
        csv_path,
        feat_type='melspectrogram',
        random_chunk=False,
        n_mels=config.n_mels,
        window_size=config.window_size,
        hop_length=config.hop_size,
        **dataset_kwargs)
    sampler = BatchSampler(
        dataset, batch_size=config.batch_size, shuffle=False)
    return DataLoader(
        dataset,
        batch_sampler=sampler,
        collate_fn=lambda x: batch_feature_normalize(
            x, mean_norm=True, std_norm=False),
        num_workers=config.num_workers,
        return_list=True)


def main(args, config):
    """The main process for test the speaker verification model

    Loads ``model.pdparams`` from ``args.load_checkpoint``, computes the
    embeddings of the enroll and test datasets found below ``args.data_dir``,
    scores the trials listed in ``config.verification_file`` and logs the
    resulting EER together with its score threshold.

    Args:
        args (argparse.Namespace): the command line args namespace; ``device``,
            ``data_dir`` and ``load_checkpoint`` are read, and
            ``load_checkpoint`` is replaced by its absolute path
        config (yacs.config.CfgNode): the yaml config
    """

    # stage0: set the device, cpu or gpu
    #         if set the gpu, paddlespeech will select a gpu according the env CUDA_VISIBLE_DEVICES
    paddle.set_device(args.device)
    # set the random seed
    seed_everything(config.seed)

    # stage1: build the dnn backbone model network
    #         we will extract the audio embedding from the backbone model
    ecapa_tdnn = EcapaTdnn(**config.model)

    # stage2: build the speaker verification eval instance with backbone model
    #         because the checkpoint dict name has the SpeakerIdetification prefix
    #         so we need to create the SpeakerIdetification instance
    #         but we actually use the backbone model to extract the audio embedding
    model = SpeakerIdetification(
        backbone=ecapa_tdnn, num_class=config.num_speakers)

    # stage3: load the pre-trained model
    args.load_checkpoint = os.path.abspath(
        os.path.expanduser(args.load_checkpoint))

    # load model checkpoint to sid model
    state_dict = paddle.load(
        os.path.join(args.load_checkpoint, 'model.pdparams'))
    model.set_state_dict(state_dict)
    logger.info(f'Checkpoint loaded from {args.load_checkpoint}')

    # stage4: construct the enroll and test dataloader
    #         the enroll dataset is expected in {args.data_dir}/vox/csv/enroll.csv,
    #         and the test dataset in {args.data_dir}/vox/csv/test.csv
    enroll_loader = _build_eval_loader(
        os.path.join(args.data_dir, "vox/csv/enroll.csv"), config)
    test_loader = _build_eval_loader(
        os.path.join(args.data_dir, "vox/csv/test.csv"), config)

    # stage5: we must set the model to eval mode
    model.eval()

    # stage6: global embedding norm to improve the performance
    #         and we create the InputNormalization instance to process the embedding mean and std norm
    logger.info(f"global embedding norm: {config.global_embedding_norm}")
    # Always bind the name: it is passed on below even when the global
    # embedding norm is disabled (it used to be unbound in that case).
    mean_var_norm_emb = None
    if config.global_embedding_norm:
        mean_var_norm_emb = InputNormalization(
            norm_type="global",
            mean_norm=config.embedding_mean_norm,
            std_norm=config.embedding_std_norm)

    # stage 7: score norm need the impostors dataset
    #          we select the train dataset as the impostors dataset
    #          and config.n_train_snts is forwarded to the dataset to limit it
    train_loader = None
    if "score_norm" in config:
        logger.info(f"we will do score norm: {config.score_norm}")
        train_loader = _build_eval_loader(
            os.path.join(args.data_dir, "vox/csv/train.csv"),
            config,
            n_train_snts=config.n_train_snts)

    # stage 8: Compute embeddings of audios in enroll and test dataset from model.
    id2embedding = {}
    logger.info("First loop for enroll and test dataset")
    compute_dataset_embedding(enroll_loader, model, mean_var_norm_emb, config,
                              id2embedding)
    compute_dataset_embedding(test_loader, model, mean_var_norm_emb, config,
                              id2embedding)

    if mean_var_norm_emb is not None:
        # Run a second time to make embedding normalization more stable.
        # Without a normalizer a second pass would only recompute the same
        # embeddings, so it is skipped in that case.
        logger.info("Second loop for enroll and test dataset")
        compute_dataset_embedding(enroll_loader, model, mean_var_norm_emb,
                                  config, id2embedding)
        compute_dataset_embedding(test_loader, model, mean_var_norm_emb,
                                  config, id2embedding)
        mean_var_norm_emb.save(
            os.path.join(args.load_checkpoint, "mean_var_norm_emb"))

    # stage 9: Compute the cohort embeddings needed by the score norm.
    train_cohort = None
    if train_loader is not None:
        train_embeddings = {}
        # cohort embedding not do mean and std norm
        compute_dataset_embedding(train_loader, model, None, config,
                                  train_embeddings)
        train_cohort = paddle.stack(list(train_embeddings.values()))

    # stage 10: compute the scores
    scores, labels = compute_verification_scores(id2embedding, train_cohort,
                                                 config)

    # stage 11: compute the EER and threshold
    scores = paddle.to_tensor(scores)
    EER, threshold = compute_eer(np.asarray(labels), scores.numpy())
    logger.info(
        f'EER of verification test: {EER*100:.4f}%, score threshold: {threshold:.5f}'
    )


if __name__ == "__main__":
    # Command line entry point: parse the arguments, load the yaml config
    # and run the verification test.
    # yapf: disable
    # The first positional parameter of ArgumentParser is `prog`, so the
    # module docstring has to be passed explicitly as the description.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--device',
                        choices=['cpu', 'gpu'],
                        default="gpu",
                        help="Select which device to run the evaluation on, defaults to gpu.")
    parser.add_argument("--config",
                        default=None,
                        type=str,
                        help="configuration file")
    parser.add_argument("--data-dir",
                        default="./data/",
                        type=str,
                        help="data directory")
    parser.add_argument("--load-checkpoint",
                        type=str,
                        default='',
                        help="Directory to load the model checkpoint from.")
    args = parser.parse_args()
    # yapf: enable
    # https://yaml.org/type/float.html
    # new_allowed=True lets the yaml file introduce keys that are not
    # declared on this (initially empty) node.
    config = CfgNode(new_allowed=True)
    if args.config:
        config.merge_from_file(args.config)

    # make the config read-only before handing it to main
    config.freeze()
    print(config)
    main(args, config)
